Import packages
Loading training data from sklearn
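If you want to use the datasets that ship with the third-party library sklearn, I list three of them here so they are easy to call and test with.
Installing sklearn
Usage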
- data = load_data(name='moons')
- data = load_data(name='circles')
- data = load_data(name='linear')
If installing sklearn is difficult, the data can also be read directly from files by passing True as the second argument:
- data = load_data('moons', True)
- data = load_data('circles', True)
- data = load_data('linear', True)
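The post does not show the body of load_data, so the following is only a minimal sketch of how the three dataset names could map onto sklearn's toy-data generators. The function name `load_data_sketch`, the `(X, y)` return shape, and the noise, factor, and seed values are all assumptions made for illustration, not the author's implementation.

```python
from sklearn.datasets import make_blobs, make_circles, make_moons


def load_data_sketch(name="moons", n_samples=200, seed=0):
    """Hypothetical stand-in for load_data: returns (X, y) for a 2-D toy dataset."""
    if name == "moons":
        # Two interleaving half circles.
        return make_moons(n_samples=n_samples, noise=0.2, random_state=seed)
    if name == "circles":
        # A small circle inside a larger one.
        return make_circles(n_samples=n_samples, noise=0.1, factor=0.5,
                            random_state=seed)
    if name == "linear":
        # Two Gaussian blobs, used here as a roughly linearly separable case.
        return make_blobs(n_samples=n_samples, centers=2, random_state=seed)
    raise ValueError(f"unknown dataset name: {name!r}")


X, y = load_data_sketch("moons")
print(X.shape, y.shape)  # (200, 2) (200,)
```

Only "linear" is a good match for a purely linear classifier; "moons" and "circles" are useful for seeing where logistic regression falls short.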
Define network
The network structure given here is that of logistic regression (LR).
For an input vector x, the probability that it belongs to class i is:
$$\begin{align}
P\left( Y=i\mid x,W,b \right) &= softmax_i\left( Wx+b \right) \\
&= \frac{e^{W_i x+b_i}}{\sum_{j}e^{W_j x+b_j}}
\end{align}$$
The model's prediction y_pred for an input vector x is the class with the highest predicted probability, that is:
$$y_{pred}=argmax_i P\left( Y=i\mid x,W,b \right)$$
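The two formulas above can be sketched in NumPy as follows. This is an illustration rather than the network code used in this post: the function names, the row-per-sample layout of `X` (so that Wx + b becomes `X @ W.T + b`), and the example values are assumptions.

```python
import numpy as np


def softmax(z):
    """Row-wise softmax; subtracting the row max avoids overflow in exp."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def predict(X, W, b):
    """Return P(Y=i | x, W, b) for every row of X, and the argmax class."""
    probs = softmax(X @ W.T + b)
    return probs, probs.argmax(axis=1)


# Two samples, two features, three classes (illustrative values).
X = np.array([[1.0, 2.0], [0.5, -2.0]])
W = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]])
b = np.zeros(3)

probs, y_pred = predict(X, W, b)
print(y_pred)  # [1 2]
```

The logits for the two samples are [1, 2, -3] and [0.5, -2, 1.5], so the predicted classes are 1 and 2; each row of `probs` sums to 1.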
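In the LR model, the parameters to solve for are the weight matrix W and the bias vector b. To fit these two parameters, a loss function must first be defined. For the multi-class logistic regression above, the loss can be built from the log-likelihood; the quantity that is minimized is the negative log-likelihood (note the minus sign):
$$L\left( \theta =\left\{ W,b \right\},D \right)=-\sum_{i=1}^{\left | D \right |}\log\left( P\left( Y=y^{\left( i \right)}\mid x^{\left( i \right)},W,b \right) \right)$$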
P.S. The softmax_cross_entropy(y_truth, logits) used in the code works as follows: it first applies softmax to logits to obtain y_pred, then computes the negative log-likelihood between y_truth and y_pred.
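To make the description of softmax_cross_entropy concrete, here is a minimal NumPy sketch of the same two steps. It is not the function from the post: the name `softmax_cross_entropy_sketch` is hypothetical, labels are assumed to be integer class indices (the original may expect one-hot vectors), and the result is averaged over the batch, which differs from the summed form only by a factor of 1/|D|.

```python
import numpy as np


def softmax_cross_entropy_sketch(y_truth, logits):
    """Mean negative log-likelihood of integer labels under softmax(logits)."""
    # log-softmax computed stably: shift by the row max before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for each sample, then negate.
    return -log_probs[np.arange(len(y_truth)), y_truth].mean()


# With all-zero logits every class has probability 1/3, so the loss is log(3).
logits = np.zeros((4, 3))
y_truth = np.array([0, 1, 2, 0])
loss = softmax_cross_entropy_sketch(y_truth, logits)
print(np.isclose(loss, np.log(3)))  # True
```

Computing log-softmax directly, rather than taking the log of a softmax output, avoids log(0) when a probability underflows.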
Define optimizer
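Deep learning is usually optimized through gradients, so the various optimizers are ultimately refinements of the gradient descent algorithm. Common optimizers include SGD, RMSprop, Adagrad, Adadelta, and Adam. This example uses stochastic gradient descent (SGD). Since most machine learning tasks amount to minimizing a loss, once the loss has been defined the rest of the work can be handed to the optimizer.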
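As a rough illustration of what the optimizer does with the gradients, the update rule of plain SGD is theta <- theta - lr * grad. The sketch below applies that rule to a toy quadratic; the helper name `sgd_step`, the learning rate, and the toy objective are assumptions for illustration and are not taken from the post's optimizer class.

```python
import numpy as np


def sgd_step(params, grads, lr=0.1):
    """One plain SGD update, applied in place: theta <- theta - lr * grad."""
    for p, g in zip(params, grads):
        p -= lr * g


# Toy objective: f(w) = ||w - target||^2, whose gradient is 2 * (w - target).
w = np.array([0.0, 0.0])
target = np.array([3.0, -1.0])
for _ in range(100):
    grad = 2.0 * (w - target)
    sgd_step([w], [grad], lr=0.1)

print(np.allclose(w, target, atol=1e-6))  # True
```

Here the gradient is exact, so this is ordinary gradient descent; the method becomes stochastic when `grad` is estimated from a randomly drawn mini-batch instead of the full dataset.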
Start Training
Choose the input data (data) for the model and set the parameters (the number of iterations, and plotting once every N iterations).
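For readers without the original training cell, a self-contained training loop for softmax regression with mini-batch SGD might look like the sketch below. Everything in it is an assumption made for illustration: the function names, the hyperparameters, and the synthetic two-class data. Plotting every N iterations is left out. It prints one line per epoch in the same format as the log in this section.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def error_rate(X, y, W, b):
    """Fraction of samples whose argmax prediction differs from the label."""
    return float(np.mean((X @ W.T + b).argmax(axis=1) != y))


def train(X_train, y_train, X_test, y_test,
          epochs=100, lr=0.1, batch_size=20, seed=0):
    rng = np.random.default_rng(seed)
    n_classes = int(y_train.max()) + 1
    W = np.zeros((n_classes, X_train.shape[1]))
    b = np.zeros(n_classes)
    for epoch in range(epochs):
        order = rng.permutation(len(X_train))  # reshuffle every epoch
        for start in range(0, len(order), batch_size):
            idx = order[start:start + batch_size]
            xb, yb = X_train[idx], y_train[idx]
            # Gradient of the mean negative log-likelihood w.r.t. the logits:
            # softmax probabilities minus the one-hot labels.
            delta = softmax(xb @ W.T + b)
            delta[np.arange(len(idx)), yb] -= 1.0
            delta /= len(idx)
            W -= lr * (delta.T @ xb)
            b -= lr * delta.sum(axis=0)
        print("epoch is %d, train error %f, test error %f"
              % (epoch, error_rate(X_train, y_train, W, b),
                 error_rate(X_test, y_test, W, b)))
    return W, b


# Synthetic, linearly separable two-class data (an assumption for illustration).
data_rng = np.random.default_rng(0)
X = data_rng.normal(size=(400, 2))
y = (X[:, 0] + X[:, 1] > 0).astype(int)
W, b = train(X[:300], y[:300], X[300:], y[300:], epochs=20)
```

The numbers it prints depend on the synthetic data and the seed, so they will not match the log recorded in this post.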
epoch is 0, train error 0.513333, test error 0.460000
epoch is 1, train error 0.513333, test error 0.165000
epoch is 2, train error 0.248333, test error 0.160000
epoch is 3, train error 0.201667, test error 0.175000
epoch is 4, train error 0.215000, test error 0.185000
epoch is 5, train error 0.206667, test error 0.185000
epoch is 6, train error 0.200000, test error 0.185000
epoch is 7, train error 0.200000, test error 0.185000
epoch is 8, train error 0.201667, test error 0.190000
epoch is 9, train error 0.203333, test error 0.190000
epoch is 10, train error 0.198333, test error 0.195000
epoch is 11, train error 0.198333, test error 0.195000
epoch is 12, train error 0.201667, test error 0.195000
epoch is 13, train error 0.201667, test error 0.195000
epoch is 14, train error 0.198333, test error 0.190000
epoch is 15, train error 0.198333, test error 0.190000
epoch is 16, train error 0.198333, test error 0.190000
epoch is 17, train error 0.195000, test error 0.190000
epoch is 18, train error 0.195000, test error 0.185000
epoch is 19, train error 0.195000, test error 0.185000
epoch is 20, train error 0.193333, test error 0.180000
epoch is 21, train error 0.191667, test error 0.180000
epoch is 22, train error 0.191667, test error 0.180000
epoch is 23, train error 0.191667, test error 0.180000
epoch is 24, train error 0.190000, test error 0.180000
epoch is 25, train error 0.191667, test error 0.175000
epoch is 26, train error 0.193333, test error 0.170000
epoch is 27, train error 0.193333, test error 0.170000
epoch is 28, train error 0.191667, test error 0.165000
epoch is 29, train error 0.191667, test error 0.165000
epoch is 30, train error 0.193333, test error 0.160000
epoch is 31, train error 0.193333, test error 0.160000
epoch is 32, train error 0.193333, test error 0.160000
epoch is 33, train error 0.193333, test error 0.160000
epoch is 34, train error 0.193333, test error 0.160000
epoch is 35, train error 0.193333, test error 0.160000
epoch is 36, train error 0.193333, test error 0.160000
epoch is 37, train error 0.195000, test error 0.160000
epoch is 38, train error 0.195000, test error 0.160000
epoch is 39, train error 0.195000, test error 0.155000
epoch is 40, train error 0.195000, test error 0.155000
epoch is 41, train error 0.195000, test error 0.155000
epoch is 42, train error 0.193333, test error 0.155000
epoch is 43, train error 0.190000, test error 0.160000
epoch is 44, train error 0.190000, test error 0.160000
epoch is 45, train error 0.188333, test error 0.160000
epoch is 46, train error 0.188333, test error 0.160000
epoch is 47, train error 0.188333, test error 0.155000
epoch is 48, train error 0.186667, test error 0.155000
epoch is 49, train error 0.186667, test error 0.155000
epoch is 50, train error 0.185000, test error 0.165000
epoch is 51, train error 0.185000, test error 0.160000
epoch is 52, train error 0.190000, test error 0.155000
epoch is 53, train error 0.188333, test error 0.155000
epoch is 54, train error 0.190000, test error 0.155000
epoch is 55, train error 0.190000, test error 0.155000
epoch is 56, train error 0.191667, test error 0.160000
epoch is 57, train error 0.190000, test error 0.160000
epoch is 58, train error 0.191667, test error 0.155000
epoch is 59, train error 0.191667, test error 0.160000
epoch is 60, train error 0.191667, test error 0.160000
epoch is 61, train error 0.190000, test error 0.155000
epoch is 62, train error 0.188333, test error 0.155000
epoch is 63, train error 0.188333, test error 0.155000
epoch is 64, train error 0.188333, test error 0.155000
epoch is 65, train error 0.188333, test error 0.155000
epoch is 66, train error 0.188333, test error 0.155000
epoch is 67, train error 0.188333, test error 0.155000
epoch is 68, train error 0.188333, test error 0.155000
epoch is 69, train error 0.186667, test error 0.155000
epoch is 70, train error 0.185000, test error 0.155000
epoch is 71, train error 0.185000, test error 0.155000
epoch is 72, train error 0.183333, test error 0.155000
epoch is 73, train error 0.181667, test error 0.155000
epoch is 74, train error 0.181667, test error 0.155000
epoch is 75, train error 0.181667, test error 0.155000
epoch is 76, train error 0.183333, test error 0.145000
epoch is 77, train error 0.183333, test error 0.145000
epoch is 78, train error 0.185000, test error 0.150000
epoch is 79, train error 0.185000, test error 0.150000
epoch is 80, train error 0.183333, test error 0.150000
epoch is 81, train error 0.183333, test error 0.150000
epoch is 82, train error 0.183333, test error 0.150000
epoch is 83, train error 0.181667, test error 0.150000
epoch is 84, train error 0.181667, test error 0.150000
epoch is 85, train error 0.181667, test error 0.150000
epoch is 86, train error 0.181667, test error 0.150000
epoch is 87, train error 0.180000, test error 0.150000
epoch is 88, train error 0.180000, test error 0.150000
epoch is 89, train error 0.180000, test error 0.150000
epoch is 90, train error 0.180000, test error 0.150000
epoch is 91, train error 0.178333, test error 0.150000
epoch is 92, train error 0.178333, test error 0.150000
epoch is 93, train error 0.180000, test error 0.150000
epoch is 94, train error 0.176667, test error 0.150000
epoch is 95, train error 0.176667, test error 0.145000
epoch is 96, train error 0.176667, test error 0.145000
epoch is 97, train error 0.176667, test error 0.145000
epoch is 98, train error 0.175000, test error 0.145000
epoch is 99, train error 0.175000, test error 0.145000
Course Testing
Hello World